Python TopologicalSorter源码解析

TopologicalSorter 是 Python 3.9 起标准库 graphlib 提供的拓扑排序工具,常用于工作流中,根据节点之间的依赖关系确定每个节点的执行顺序。

例如上图中定义了 5 个任务,我们希望在 HTTP Activate 执行结束后,能够获取到下一批可以执行的节点:Command Activate 和 Download Activate。这样既保证了调用顺序,又可以让 Command Activate 和 Download Activate 并行执行,从而提高整个流程的执行效率。

Python 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Public names of this module.
__all__ = ["TopologicalSorter", "CycleError"]

# Sentinel values stored in _NodeInfo.npredecessors once a node is no longer
# waiting on predecessors (a real predecessor count is always >= 0).
# _NODE_OUT: the node has been returned by get_ready() but not yet marked done.
_NODE_OUT = -1
# _NODE_DONE: the node has been marked done via done().
_NODE_DONE = -2


class _NodeInfo:
    """Bookkeeping record kept for a single graph node."""

    __slots__ = ("node", "npredecessors", "successors")

    def __init__(self, node):
        # Nodes that have this node as a predecessor. Duplicates are allowed
        # provided each occurrence is also counted in that successor's
        # npredecessors.
        self.successors = []

        # Count of predecessors, normally >= 0. After the count reaches 0 and
        # the node is handed out by get_ready() it becomes _NODE_OUT; after
        # done() is called for the node it becomes _NODE_DONE.
        self.npredecessors = 0

        self.node = node


class CycleError(ValueError):
    """Error raised by TopologicalSorter.prepare() when the graph has a cycle.

    prepare() raises it with two arguments: a message string and the list of
    nodes forming the detected cycle (first and last element are the same).
    """


class TopologicalSorter:
    """Incrementally produce a topological ordering of a dependency graph.

    Nodes must be hashable because they are used as dictionary keys.

    Typical use: register nodes with add() (or the *graph* constructor
    argument), call prepare() once, then loop: get_ready() returns the nodes
    whose predecessors are all done, and done() reports nodes as finished,
    which may make further nodes ready. is_active() tells when to stop.
    static_order() wraps that loop for the simple sequential case.
    """

    def __init__(self, graph=None):
        """Create a sorter, optionally pre-populated from *graph*.

        graph: optional mapping of node -> iterable of its predecessors.
        """
        # node -> _NodeInfo for every node seen so far, whether it was added
        # directly or only mentioned as a predecessor.
        self._node2info = {}
        # None until prepare() is called; afterwards the list of nodes that
        # are ready but have not been returned by get_ready() yet.
        self._ready_nodes = None
        # Number of nodes returned by get_ready() so far.
        self._npassedout = 0
        # Number of nodes marked done via done() so far.
        self._nfinished = 0

        if graph is not None:
            for node, predecessors in graph.items():
                self.add(node, *predecessors)

    def _get_nodeinfo(self, node):
        """Return the _NodeInfo for *node*, creating it on first use."""
        if (result := self._node2info.get(node)) is None:
            self._node2info[node] = result = _NodeInfo(node)
        return result

    def add(self, node, *predecessors):
        """Add *node* together with the nodes it depends on.

        node: the node being declared.
        predecessors: nodes that must be done before *node* becomes ready.

        May be called several times for the same node; the predecessor counts
        accumulate. Predecessors that were never added themselves are created
        automatically with no predecessors of their own.

        Raises ValueError if called after prepare().
        """
        if self._ready_nodes is not None:
            raise ValueError("Nodes cannot be added after a call to prepare()")

        # Create the node if needed and account for the new predecessors.
        nodeinfo = self._get_nodeinfo(node)
        nodeinfo.npredecessors += len(predecessors)

        # Create each predecessor if needed and record *node* as one of its
        # successors, so that done(pred) can later decrement node's counter.
        for pred in predecessors:
            pred_info = self._get_nodeinfo(pred)
            pred_info.successors.append(node)

    def prepare(self):
        """Freeze the graph, compute the initial ready nodes, check for cycles.

        Raises ValueError if called more than once, and CycleError (with the
        detected cycle as second argument) if the graph contains a cycle.
        """
        if self._ready_nodes is not None:
            raise ValueError("cannot prepare() more than once")

        # Every node without predecessors can be handed out immediately.
        self._ready_nodes = [
            i.node for i in self._node2info.values() if i.npredecessors == 0
        ]
        # The ready list is set before the cycle check, so the sorter counts
        # as "prepared" even when CycleError is raised below.
        cycle = self._find_cycle()
        if cycle:
            raise CycleError("nodes are in a cycle", cycle)

    def get_ready(self):
        """Return a tuple of all nodes that are currently ready.

        Each node is returned at most once. Returned nodes are flagged
        _NODE_OUT until done() is called for them. An empty tuple means no
        node is ready at the moment.

        Raises ValueError if prepare() has not been called.
        """
        if self._ready_nodes is None:
            raise ValueError("prepare() must be called first")

        # Snapshot the ready list into an immutable result.
        result = tuple(self._ready_nodes)
        n2i = self._node2info
        # Flag every returned node as "passed out".
        for node in result:
            n2i[node].npredecessors = _NODE_OUT

        # Empty the list in place so done() can refill it for the next call.
        self._ready_nodes.clear()
        self._npassedout += len(result)
        return result

    def is_active(self):
        """Return True if more progress can be made, False otherwise.

        Progress is possible while there are ready nodes that get_ready() has
        not returned yet, or while fewer nodes have been marked done than
        get_ready() has returned.

        Raises ValueError if prepare() has not been called.
        """
        if self._ready_nodes is None:
            raise ValueError("prepare() must be called first")
        return self._nfinished < self._npassedout or bool(self._ready_nodes)

    def __bool__(self):
        """Truth value of the sorter; identical to is_active()."""
        return self.is_active()

    def done(self, *nodes):
        """Mark *nodes* (previously returned by get_ready()) as processed.

        Each successor of a finished node has its pending-predecessor count
        decremented; successors reaching zero are queued for the next
        get_ready() call.

        Raises ValueError if prepare() has not been called, if a node was
        never added, if it has not been returned by get_ready() yet, or if it
        was already marked done.
        """
        if self._ready_nodes is None:
            raise ValueError("prepare() must be called first")

        n2i = self._node2info

        for node in nodes:

            # The node must be known to the sorter.
            if (nodeinfo := n2i.get(node)) is None:
                raise ValueError(f"node {node!r} was not added using add()")

            # Only nodes currently flagged _NODE_OUT may be marked done.
            stat = nodeinfo.npredecessors
            if stat != _NODE_OUT:
                if stat >= 0:
                    raise ValueError(
                        f"node {node!r} was not passed out (still not ready)"
                    )
                elif stat == _NODE_DONE:
                    raise ValueError(f"node {node!r} was already marked done")
                else:
                    # Internal invariant: no other negative state exists.
                    assert False, f"node {node!r}: unknown status {stat}"

            # Flag the node as finished.
            nodeinfo.npredecessors = _NODE_DONE

            # Release the nodes that were waiting on this one.
            for successor in nodeinfo.successors:
                successor_info = n2i[successor]
                successor_info.npredecessors -= 1
                if successor_info.npredecessors == 0:
                    # Nothing left to wait for: hand it out on the next
                    # get_ready() call.
                    self._ready_nodes.append(successor)
            self._nfinished += 1

    def _find_cycle(self):
        """Return a list of nodes forming a cycle, or None if there is none.

        Iterative depth-first search along successor edges. The returned list
        starts and ends with the same node.
        """
        n2i = self._node2info
        stack = []  # nodes on the current DFS path
        itstack = []  # per path entry: __next__ of its successor iterator
        seen = set()  # every node visited so far, on the path or not
        node2stacki = {}  # node on the current path -> its index in stack

        for node in n2i:
            if node in seen:
                continue

            while True:
                if node in seen:
                    # If we have seen already the node and is in the
                    # current stack we have found a cycle.
                    if node in node2stacki:
                        return stack[node2stacki[node]:] + [node]
                    # else go on to get next successor
                else:
                    seen.add(node)
                    itstack.append(iter(n2i[node].successors).__next__)
                    node2stacki[node] = len(stack)
                    stack.append(node)

                # Backtrack to the topmost stack entry with
                # at least another successor.
                while stack:
                    try:
                        node = itstack[-1]()
                        break
                    except StopIteration:
                        del node2stacki[stack.pop()]
                        itstack.pop()
                else:
                    # Path exhausted: continue with the next unvisited node.
                    break
        return None

    def static_order(self):
        """Yield all nodes in a topological order.

        Calls prepare() itself, so it raises whatever prepare() raises (for
        example CycleError), on first iteration because this is a generator.
        """
        self.prepare()
        while self.is_active():
            node_group = self.get_ready()
            yield from node_group
            self.done(*node_group)


if __name__ == '__main__':
    # Demo graph: each key maps to the nodes it depends on.
    graph = {
        "C": ["A"],
        "B": ["A", "1"],
        "D": ["B", "C"],
        "E": ["B"],
        "F": ["D", "E"],
    }
    ts = TopologicalSorter(graph)
    ts.prepare()

    # Print one tuple per batch of nodes that can run together, and stop as
    # soon as get_ready() hands back an empty tuple.
    while batch := ts.get_ready():
        print(batch)
        ts.done(*batch)

Go 实现

参照上面的 Python 实现,用 Go 语言仿写了一个版本。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package utility

import (
"errors"
)

// Sentinel values stored in TopologicalNode.Predecessors once a node is no
// longer waiting on predecessors.
const (
	// NodeOut marks a node that GetReady has returned but Done has not seen yet.
	NodeOut = -1
	// NodeDone marks a node that has been passed to Done.
	NodeDone = -2
)

// TopologicalNode is the bookkeeping record the sorter keeps for one node.
type TopologicalNode struct {
	N ActivateInterface // the node itself
	Predecessors int // number of predecessors not yet done; Prepare queues the node when this is 0. Later set to NodeOut / NodeDone
	Successors []ActivateInterface // nodes that declared this node as a predecessor
}

// TopologicalSorter hands out nodes of a dependency graph in dependency order.
type TopologicalSorter struct {
	ReadyNodes []ActivateInterface // nodes that are ready and not yet returned by GetReady
	NodeMap map[ActivateInterface]*TopologicalNode // every known node -> its bookkeeping record
}

// hasCycle reports whether graph contains a directed cycle.
//
// graph maps each node to its adjacent nodes (as passed to New: the node's
// predecessors). It runs a recursive depth-first search that tracks, for
// every node, whether it is untouched, on the current search path, or fully
// explored; reaching a node that is still on the path means a cycle.
func (t *TopologicalSorter) hasCycle(graph map[ActivateInterface][]ActivateInterface) bool {
	const (
		unvisited = iota // zero value: node not reached yet
		onPath           // node is on the current DFS path
		finished         // node and everything reachable from it is explored
	)
	state := make(map[ActivateInterface]int)

	var visit func(cur ActivateInterface) bool
	visit = func(cur ActivateInterface) bool {
		state[cur] = onPath
		for _, next := range graph[cur] {
			switch state[next] {
			case unvisited:
				if visit(next) {
					return true
				}
			case onPath:
				// Back edge to a node on the current path.
				return true
			}
		}
		state[cur] = finished
		return false
	}

	for start := range graph {
		if state[start] == unvisited && visit(start) {
			return true
		}
	}
	return false
}

// New loads graph into the sorter. graph maps each node to the nodes it
// depends on (its predecessors).
//
// It returns an error if ready nodes are already queued (Prepare has run) or
// if graph contains a cycle. The cycle check runs before anything is stored,
// so a rejected graph leaves the sorter unchanged.
func (t *TopologicalSorter) New(graph map[ActivateInterface][]ActivateInterface) error {
	if len(t.ReadyNodes) > 0 {
		return errors.New("nodes cannot be added after a call to Prepare()")
	}
	if t.hasCycle(graph) {
		return errors.New("nodes are in a cycle")
	}
	// A zero-value TopologicalSorter has a nil NodeMap; writing to a nil map
	// panics, so allocate it before the first node is registered.
	if t.NodeMap == nil {
		t.NodeMap = make(map[ActivateInterface]*TopologicalNode, len(graph))
	}
	for node, predecessors := range graph {
		t.addNode(node, predecessors...)
	}
	return nil
}

// Done 处理完成后需要调用 Done 方法,将其从就绪中移除
func (t *TopologicalSorter) Done(node ActivateInterface) error {
n := t.getNode(node)

//确保当前要被 done 的节点已经被标记为 _NODE_OUT(-1) 了。
if n.Predecessors != NodeOut {
if n.Predecessors >= 0 {
return errors.New("node was not passed out (still not ready)")
} else if n.Predecessors == NodeDone {
return errors.New("node was already marked done")
} else {
return errors.New("unknown status")
}
}
// 将本身置为过期
n.Predecessors = NodeDone
for _, successor := range n.Successors {
successorNode := t.getNode(successor)
// 依赖减-1
successorNode.Predecessors -= 1
if successorNode.Predecessors == 0 {
// 没有依赖的节点,加入 ReadyNodes
t.ReadyNodes = append(t.ReadyNodes, successor)
}
}
return nil
}

// Prepare 运行 GetRead 前需要运行 Prepare。Prepare 添加就绪节点
func (t *TopologicalSorter) Prepare() error {
if len(t.ReadyNodes) > 0 {
return errors.New("cannot prepare() more than once")
}
for _, v := range t.NodeMap {
// 等于0才加入就绪列表,小于0说明过期了。大于0说明还有前置依赖节点没有执行。
if v.Predecessors == 0 {
t.ReadyNodes = append(t.ReadyNodes, v.N)
}
}
return nil
}

// GetReady 获取就绪节点
func (t *TopologicalSorter) GetReady() (data []ActivateInterface) {
data = t.ReadyNodes
// 将当前就绪节点的 Predecessors 的值设置为 -1
for _, d := range data {
t.NodeMap[d].Predecessors = NodeOut
}
// 每次获取 Read 节点后,需要清除 ReadyNodes
t.ReadyNodes = t.ReadyNodes[:0:0]
return data
}

// 获取节点,有缓存则直接返回。没有则创建一个新的
func (t *TopologicalSorter) getNode(node ActivateInterface) (n *TopologicalNode) {
n, ok := t.NodeMap[node]
if !ok {
n = &TopologicalNode{N: node, Predecessors: 0, Successors: []ActivateInterface{}}
t.NodeMap[node] = n
}
return n
}

// 新增节点,并设置节点的 Predecessors 和 Successors
func (t *TopologicalSorter) addNode(node ActivateInterface, predecessors ...ActivateInterface) {
nodeInfo := t.getNode(node)
nodeInfo.Predecessors += len(predecessors)
for _, p := range predecessors {
cacheNodeInfo := t.getNode(p)
cacheNodeInfo.Successors = append(cacheNodeInfo.Successors, node)
}
}